rm(list = ls())
library(tidyverse)
library(dplyr)
library(magrittr)  # provides the %<>% compound-assignment pipe (not attached by tidyverse)
library(zoo)       # na.approx() for interpolating census values
library(nnet)      # multinom()

# Work relative to this script's folder when it is run from RStudio. Outside
# RStudio (e.g. via Rscript) the RStudio API is unavailable, so keep the current
# working directory; it must then be the folder containing Data/, Output/ and
# Figures/, which all paths below are relative to.
if (requireNamespace('rstudioapi', quietly = TRUE) && rstudioapi::isAvailable()) {
  script_path <- rstudioapi::getActiveDocumentContext()$path
  # an unsaved document has an empty path; setwd('') would fail
  if (nzchar(script_path)) setwd(dirname(script_path))
  rm(script_path)
}

load('Data/Census/census_plc.RData')
# Census covariates by place and year (`census` comes from the .RData above).
# Demographic counts are blanked where the population is missing or zero, then
# linearly interpolated within each place. na.approx() interpolates by row
# position after sorting on year, not by the year value itself. Rows whose
# population is still missing afterwards are dropped.
cen <- census[, c('place_fips', 'PLACE', 'year', 'pop', 'foreign', 'nhis.whi', 'nhis.bla', 'hisp', 'bla', 'asi')] %>%
  mutate(year = as.numeric(year)) %>%
  # a count over a missing/zero population is unusable; treat it as missing
  mutate_at(vars(foreign:asi), ~ ifelse(is.na(pop) | pop == 0, NA, .x)) %>%
  arrange(place_fips, year) %>%
  group_by(place_fips) %>%
  # fill interior gaps only; leading/trailing NAs are kept (na.rm = FALSE)
  mutate_at(vars(pop:asi), list(na.approx), na.rm = FALSE) %>%
  # grouping was only needed for the per-place interpolation
  ungroup() %>%
  drop_na(pop) %>%
  # Fold place code 1836010 into 1836003, presumably another Indianapolis, IN
  # code (the election data maps 1836000 -> 1836003 for Indianapolis).
  # NOTE(review): if both codes exist for the same year, the merge on
  # (place_fips, year) will duplicate election rows -- verify.
  mutate(place_fips = ifelse(place_fips == 1836010, 1836003, place_fips))
rm(census)

# Election data with the true race (race.mt) and several predicted race
# measures per row, restricted to rows where every measure is present, then
# merged with the census covariates in `cen` and given model covariates.
t <- readRDS('Data/Prediction/img5_weighted_opencv_cov2f_nofw_level.rds') %>%
  filter(race.mt != 'non') %>%
  drop_na(race.mt, race.s, race.gs, race.fbisg, race.ff, race.lstm_wk, race.hybrid) %>%
  # census decade used for the merge: the election's decade, clamped to 1970-2010
  mutate(cyr = pmin(pmax(floor(year / 10) * 10, 1970), 2010)) %>%
  # align place codes with the census codes; the codes involved are disjoint,
  # so applying the recodes one after another is the same as a nested choice
  mutate(place_fips = ifelse(place_fips == 2148000 & cyr == 2010, 2148006, place_fips),  # Louisville, KY
         place_fips = ifelse(place_fips == 1517000 & cyr == 2010, 1571550, place_fips),  # Honolulu, HI
         place_fips = ifelse(place_fips == 1836000, 1836003, place_fips)) %>%           # Indianapolis, IN
  inner_join(cen %>% select(-PLACE), by = c('place_fips', 'cyr' = 'year')) %>%
  # group shares of the place population
  mutate(pct_min     = (hisp + bla + asi) / pop,
         pct_bla     = bla / pop,
         pct_his     = hisp / pop,
         pct_asi     = asi / pop,
         pct_foreign = foreign / pop) %>%
  mutate(office_consolidated = ifelse(office_consolidated == 'Board Member', 'City Council', office_consolidated),
         # after the merge, pre-1970 elections get their own decade code:
         # 1960s -> 1960, anything earlier pooled as 1950
         cyr     = ifelse(year < 1970, ifelse(year < 1960, 1950, 1960), cyr),
         inc     = as.factor(ifelse(is.na(incumbent), 'missing', incumbent)),
         atlarge = as.numeric(grepl('large', district, ignore.case = TRUE)),
         pid     = ifelse(is.na(pid_final), 'Unk', pid_final))

# Multinomial logit of candidate race on the place's Asian/Black/Hispanic
# population shares, once per race measure (the true race 'mt' and each
# prediction method), with four nested specifications:
#   .1  population shares only
#   .5  + incumbency, office type, at-large seat, party
#   .6  + state fixed effects
#   .7  + decade (cyr) fixed effects
# `fit` holds the models, named '<race column>.<model no.>' (e.g. 'race.mt.7').
# `results` holds one row per model and non-baseline outcome, with coefficient
# (_b) and standard-error (_se) columns for the substantive covariates.

# Convert a multinom coefficient or standard-error matrix into a data frame with
# a 'rowname' column (the outcome category) plus the covariates matching the
# pattern, each renamed with `suffix`. Non-matching columns (e.g. the intercept
# and the state/decade dummies) are dropped.
coef_table <- function(mat, suffix) {
  out <- as.data.frame(mat) %>%
    rownames_to_column() %>%
    select(rowname, matches('pct_|inc|office|pid|large'))
  names(out)[-1] <- paste0(names(out)[-1], suffix)
  out
}

race_measures <- c('mt', 's', 'gs', 'fbisg', 'ff', 'hybrid', 'lstm_wk')
race_order    <- c('whi', 'asi', 'bla', 'his', 'oth')  # first level is the baseline

fit      <- list()
res_list <- list()
for (r in race_measures) {
  print(r)
  Y     <- paste0('race.', r)
  regdf <- t
  # Canonical category order restricted to the categories that actually occur,
  # so no level is empty and no observed category is silently turned into NA.
  racelevel <- intersect(race_order, unique(regdf[[Y]]))
  regdf$Y   <- factor(regdf[[Y]], levels = racelevel)
  set.seed(8901324)
  fit[[paste0(Y, '.1')]] <- multinom(Y ~ pct_asi + pct_bla + pct_his, data = regdf, maxit = 1000, model = TRUE)
  fit[[paste0(Y, '.5')]] <- multinom(Y ~ pct_asi + pct_bla + pct_his + inc + as.factor(office_consolidated) + atlarge + as.factor(pid), data = regdf, maxit = 1000, model = TRUE)
  fit[[paste0(Y, '.6')]] <- multinom(Y ~ pct_asi + pct_bla + pct_his + inc + as.factor(office_consolidated) + atlarge + as.factor(pid) + as.factor(state), data = regdf, maxit = 1000, model = TRUE)
  fit[[paste0(Y, '.7')]] <- multinom(Y ~ pct_asi + pct_bla + pct_his + inc + as.factor(office_consolidated) + atlarge + as.factor(pid) + as.factor(state) + as.factor(cyr), data = regdf, maxit = 1000, model = TRUE)

  # SUMMARIZING
  for (M in c(1, 5, 6, 7)) {
    m <- fit[[paste0(Y, '.', M)]]
    # Summarise once per model. The models were fitted without Hess = TRUE, so
    # summary() must rebuild the Hessian for the standard errors on every call.
    s <- summary(m)
    res_list[[length(res_list) + 1]] <-
      coef_table(s$coefficients, '_b') %>%
      left_join(coef_table(s$standard.errors, '_se'), by = 'rowname') %>%
      mutate(df    = 'img5_weighted_opencv_cov2f_nofw_level',
             Yvar  = r,
             modno = M,
             # residuals have one row per observation used in the fit; this
             # stays right even if multinom drops an empty outcome class
             N     = as.numeric(NROW(residuals(m)))) %>%
      rename(outcome = rowname)
    rm(m, s)
  }
  rm(Y, racelevel, regdf)
}
results <- bind_rows(res_list)
rm(res_list)

save(results, file = 'Output/est_sum_img5_weighted_opencv_cov2f_nofw_level.rdata')
save(fit, file = 'Output/est_img5_weighted_opencv_cov2f_nofw_level.rdata')


# Figure 4
# Own-group effects: for each outcome (Black, Asian, Hispanic), keep the
# coefficient on that group's population share, labelled by race measure.
method_labels <- c(mt = 'True', s = 'BSO', gs = 'BISG', hybrid = 'Hybrid',
                   lstm_wk = 'LSTM', fbisg = 'fBISG', ff = 'Image')
group_labels  <- c(bla = 'Black', asi = 'Asian', his = 'Hispanic')

df <- results %>%
  mutate(x = factor(unname(method_labels[Yvar]),
                    levels = c("True", "BSO", "BISG", "fBISG", "LSTM", "Image", "Hybrid"))) %>%
  # one row per (model, outcome, pct_* covariate) with its estimate `b` and
  # standard error `se`; xvar is the group code between 'pct_' and the suffix
  pivot_longer(cols = matches('^pct_.*_(b|se)$'),
               names_to = c('xvar', '.value'),
               names_pattern = '^pct_(.*)_(b|se)$') %>%
  mutate(outcome = coalesce(unname(group_labels[outcome]), outcome),
         xvar    = coalesce(unname(group_labels[xvar]), xvar)) %>%
  # keep only each group's own population-share coefficient
  filter(outcome == xvar)
rm(method_labels, group_labels)

# Plot data: model 7 (all controls) only. `trueb`/`truese` copy the true-race
# estimate for the same outcome onto each predicted-measure row; `grp` is the
# x position of each prediction method (1-6, since 'True' is factor level 1).
t <- df %>%
  filter(modno == 7) %>%
  group_by(modno, outcome, xvar) %>%
  mutate(trueb  = mean(ifelse(x == 'True', b,  NA), na.rm = TRUE),
         truese = mean(ifelse(x == 'True', se, NA), na.rm = TRUE)) %>%
  ungroup() %>%
  filter(x != 'True') %>%
  mutate(grp = as.numeric(x) - 1)
true <- t %>%
  distinct(modno, trueb, truese, outcome, xvar) %>%
  arrange(modno) %>%
  mutate(xmin = 0.25, xmax = 6.75)

p <- ggplot() +
  # +/- 2 SE band around the true-race estimate
  geom_rect(data = true, aes(xmin = xmin, xmax = xmax, ymin = trueb - 2 * truese, ymax = trueb + 2 * truese),
            fill = 'gray50', color = 'gray85', alpha = 0.3) +
  # zero-height rectangle: draws a horizontal line at the true-race estimate
  geom_rect(data = true, aes(xmin = xmin, xmax = xmax, ymin = trueb, ymax = trueb),
            fill = 'black', color = 'black') +
  geom_point(data = t, aes(x = grp, y = b, color = x), position = position_dodge(width = 0.5)) +
  geom_errorbar(data = t, aes(x = grp, ymin = b - 2 * se, ymax = b + 2 * se, color = x),
                position = position_dodge(width = 0.5), width = 0.1) +
  facet_wrap(vars(outcome), nrow = 1, scales = 'free') +
  labs(x = '', y = expression(hat(beta)[k])) +
  scale_color_manual(name = "", values = c('#38761d', '#55b32c', '#234912', 'gray45', 'gray70', 'red')) +
  # grp 1-6 labelled with the prediction methods in factor order
  scale_x_continuous(expand = c(0, 0),
                     breaks = 1:6,
                     labels = c("BSO", "BISG", "fBISG", "LSTM", "Image", "Hybrid")) +
  guides(colour = guide_legend(nrow = 1)) +
  theme_bw() +
  theme(legend.position = "bottom",
        legend.title = element_blank(),
        legend.text = element_text(size = 8),
        legend.key = element_blank(),
        # ggplot2 >= 3.4 calls this `linewidth`; `size` still works there with a deprecation warning
        legend.background = element_rect(fill = "white", colour = "black", size = 0.2),
        axis.ticks.x = element_blank(),
        plot.title = element_text(size = 8, hjust = 0.5),
        panel.grid.major.x = element_blank(),
        panel.grid.minor.y = element_blank(),
        legend.box.margin = margin(-15, 0, 0, 0),
        axis.text.x = element_text(angle = 45, hjust = 1),
        strip.text.x = element_text(size = 8))

ggsave('Figures/Figure 4.png', p, width = 7, height = 3.1, dpi = 600)
  
